import cdtools
from cdtools.tools import plotting as p
from matplotlib import pyplot as plt
from matplotlib import animation
import matplotlib.ticker as ticker
from scipy.io import loadmat, savemat
import numpy as np
import torch as t
from helper_functions import compare_ptycho_to_rpi, compare_rpi_to_rpi, hann_window,  get_circular_lineout, fourier_pad

# Make saved PDFs have a transparent figure background while the plotting
# area itself stays opaque white.
plt.rcParams["figure.facecolor"] = (1.0, 1.0, 1.0, 0.0)   # transparent
plt.rcParams["axes.facecolor"] = (1.0, 1.0, 1.0, 1.0)     # opaque white
plt.rcParams["savefig.facecolor"] = (1.0, 1.0, 1.0, 0.0)  # transparent

#
# A few configuration parameters for the script
#

# If True, also save the raw image data behind every imshow, in addition to
# the full figures embedded in axes. The raw images are used for the paper
# figures for image data.
save_output = True
if save_output:
    from pathlib import Path
    Path('figs').mkdir(exist_ok=True)

# The first exposure (index into the RPI dataset) to analyze
start = 4

# The number of consecutive exposures to analyze, starting at `start`
n_exposures = 64

# Which running means to plot: an entry N shows the mean of the first N
# analyzed exposures (e.g. 1 = a single shot, 64 = all of them).
plot_sums = [1, 2, 4, 12, 64]

# Fail fast: each entry N is later used as summed_comparisons[N - 1], which
# is only computed after all the (slow) comparisons. An entry of 0 would
# silently select the last sum, and one above n_exposures would only raise
# an IndexError at the very end of the run.
if any(not 1 <= n_summed <= n_exposures for n_summed in plot_sums):
    raise ValueError(
        f'Every entry of plot_sums must be between 1 and n_exposures '
        f'({n_exposures}); got {plot_sums}')

# We load the raw datasets:
ptycho_filename = 'magnetic_ptycho_lcp.cxi'
rpi_filename = 'magnetic_rpi_lcp.cxi'

def make_window(center, radius):
    """Return an index expression selecting a window around *center*.

    Along each axis the window spans ``[c - radius, c + radius)`` for the
    corresponding coordinate ``c`` of *center*, so the result can be used
    directly as ``array[make_window(center, radius)]``. Any number of axes
    is supported; for a 2-tuple *center* the result is a 2-tuple of slices.

    Note that an upper bound past the end of an axis is silently truncated
    by numpy indexing; this function cannot check it without the array.

    Raises:
        ValueError: if the window would start below index 0. Numpy would
            otherwise interpret the negative start as counting from the end
            of the axis and return an empty or wrong region.
    """
    lower_bounds = [c - radius for c in center]
    if any(lo < 0 for lo in lower_bounds):
        raise ValueError(
            f'Window of radius {radius} around {tuple(center)} starts at a '
            f'negative index {tuple(lower_bounds)}')
    return tuple(slice(c - radius, c + radius) for c in center)


#
# Load the ptychography results
#

# Ptychography reconstruction; it serves as the reference that every RPI
# result is compared against below.
ptycho_results = loadmat('results/magnetic_ptycho_lcp_synthesized.mat')

# Where the RPI field of view sits inside the ptycho reconstruction.
ptycho_center = (455, 460)  # for 037
# ptycho_center = (526, 534)  # for 036

# Half-height / half-width of the region cut out of the ptycho
# reconstruction around ptycho_center.
ptycho_radius = 450
ptycho_window = make_window(ptycho_center, ptycho_radius)


#
# Load the RPI results
#

# Total width of the low-resolution object representation used for RPI;
# it also selects which results file is loaded.
resolution = 200
rpi_results = loadmat(f'results/single_shot_rpi_lcp_{resolution}.mat')

#
# Define a helper function to save image data
#

def save_latest_im(filename):
    """Write the pixel data of the image on the current axes to *filename*.

    Only the image array is written (no axes, ticks or colorbar), rendered
    with the displayed image's colormap. The image taken is the first one
    returned by ``plt.gca().get_images()``.

    NOTE(review): ``plt.imsave`` is called without ``vmin``/``vmax``, so the
    color scale is derived from the array's own min/max rather than from any
    limits set on the displayed image (e.g. via ``plt.clim``).
    """
    axes_image = plt.gca().get_images()[0]
    pixel_data = axes_image.get_array()
    colormap = axes_image.get_cmap()
    plt.imsave(filename, pixel_data, cmap=colormap)



#
# Load some RPI data to show a raw diffraction pattern
#


# Detector conversion factors.
# NOTE(review): 60 / 3.6 is presumably the photon energy (eV) divided by
# ~3.6 eV per electron-hole pair in silicon -- confirm the photon energy.
electronsperphoton = 60 / 3.6
# The full-well capacity is 200,000 electrons; presumably the 2**16 ADU range
# of the 16-bit digitizer spans that full well.
ADUperelectron = 2**16 / 200000 # from the MTE-2048B datasheet
ADUperphoton = ADUperelectron * electronsperphoton

print('ADU per electron', ADUperelectron)
print('ADU per photon', ADUperphoton)

# Show one raw diffraction pattern, converted to photons, on a log scale
rpi_dataset = cdtools.datasets.Ptycho2DDataset.from_cxi('data/' + rpi_filename)
example_pattern = rpi_dataset.patterns[start] / ADUperphoton
# Clamp negative values to zero so that log10(x + 1) below is well defined
example_pattern = t.clamp(example_pattern, min=0)

p.plot_real(t.log10(example_pattern + 1))

if save_output:
    save_latest_im('figs/lcp_RPI_example_pattern.png')
    plt.savefig('figs/lcp_RPI_example_pattern.pdf')

# Photon statistics are reported over the first `n_stats_shots` exposures
n_stats_shots = 128
total_photons = t.sum(rpi_dataset.patterns[:n_stats_shots], dim=(1, 2)) / ADUperphoton
mean_photons = t.mean(total_photons)
print('Mean Photons Per Shot:', mean_photons)
# NOTE(review): assumes a 1024 x 1024 detector -- confirm
mean_photons_per_det_pixel = mean_photons / (1024**2)
mean_photons_per_res_element = mean_photons / 24672 # divided by SBP
print('Mean Photons Per Pixel:', mean_photons_per_det_pixel)
print('Mean Photons Per Resolution Element:', mean_photons_per_res_element)
print('RPI Number of Pixels:', rpi_results['objs'][0].shape[-1])
print('RPI Pixel Size:', np.abs(rpi_results['basis'][0,1]))
print('Ptycho Pixel Size:', np.abs(ptycho_results['basis'][0,1]))

# Per-shot photon counts for exactly the exposures analyzed below. This
# follows `start` / `n_exposures` instead of a hardcoded 64-shot slice of
# the statistics prefix, which would silently truncate if the analyzed
# range extended past it.
analyzed_photons = (
    t.sum(rpi_dataset.patterns[start:start + n_exposures], dim=(1, 2))
    / ADUperphoton)

plt.figure(figsize=(3.75,2.5))
plt.plot(analyzed_photons)
plt.xlabel('Shot Index')
plt.ylabel('Measured Photons')
plt.tight_layout()

if save_output:
    plt.savefig('figs/lcp_Photons_per_shot.pdf')


#
# Normalize the RPI results
#

# Compare every single-shot RPI reconstruction with the ptychography result.
# The comparison also returns each RPI object shifted to a common location
# ('shifted_rpi_obj'), which is what gets accumulated into running sums.

single_shot_comparisons = []
summed_ims = []
running_sum = None

for shot in range(start, start + n_exposures):
    print(f'Comparing image {shot} with ptycho')
    single_result = {
        'obj': rpi_results['objs'][shot],
        'basis': rpi_results['basis'],
    }
    comparison = compare_ptycho_to_rpi(
        ptycho_results,
        single_result,
        ptycho_window,
        window_size_factor=0.8,
        nbins=25,
    )
    single_shot_comparisons.append(comparison)

    shifted = comparison['shifted_rpi_obj']
    running_sum = shifted if running_sum is None else running_sum + shifted
    summed_ims.append(running_sum)

# Convert the running sums into running means (entry k = mean of k+1 shots)
summed_ims = [total / count for count, total in enumerate(summed_ims, 1)]

#
# Calculate resolution metrics for all the summed images
#

summed_comparisons = []
for n_summed, mean_im in enumerate(summed_ims, 1):
    plural = "" if n_summed == 1 else "s"
    print(f'Comparing sum of {n_summed} image{plural} with ptycho')
    summed_comparisons.append(compare_ptycho_to_rpi(
        ptycho_results,
        {'obj': mean_im, 'basis': rpi_results['basis']},
        ptycho_window,
        window_size_factor=0.8,
        nbins=25,
    ))

#
# Now we plot an example single-shot RPI result
#


crop = 60 # This is tuned manually to match the window_size_factor

# Display limits shared by every amplitude / phase figure in this section,
# so the single-shot and summed images are shown on the same color scale.
amplitude_limits = (0.85, 1.15)
phase_limit = 0.1  # phase is clipped to [-phase_limit, phase_limit]

# We manually remove a small residual phase ramp, which we believe comes
# from a small pointing instability in the probe at a deep subpixel level.
#
# This phase ramp is fundamentally undetermined, and originally was set by
# setting the phase of the original ptychography results to be as flat as
# possible. Because there is no unique result for it in the first case, we
# feel comfortable manually setting it for best visibility of the magnetic
# structures


def _prepare_for_display(obj, ramp_slope=0.001):
    """Return a display-normalized copy of the complex object *obj*.

    Multiplies by the manual phase ramp ``exp(-1j * ramp_slope * col)``,
    where ``col`` is the index along the last axis. It then scales to unit
    mean amplitude and removes the global phase, so the sum of the result
    is (up to rounding) real and non-negative. The ramp is built from
    *obj*'s own shape rather than from another image's, so objects of any
    width can be passed.
    """
    columns = np.arange(obj.shape[-1])
    ramp = np.exp(-1j * ramp_slope * columns)
    normalized = ramp * obj / np.mean(np.abs(obj))
    return normalized * np.exp(-1j * np.angle(np.sum(normalized)))


def _plot_rpi_views(obj, tag):
    """Plot colorized, amplitude and phase views of the complex image *obj*.

    When ``save_output`` is set, each view is saved as
    ``figs/lcp_RPI_{tag}_{view}.png`` (raw image) and ``.pdf`` (figure).
    """
    basis = rpi_results['basis']

    fig = plt.figure(dpi=300)
    p.plot_colorized(obj, basis=basis, fig=fig)
    if save_output:
        save_latest_im(f'figs/lcp_RPI_{tag}_colorized.png')
        plt.savefig(f'figs/lcp_RPI_{tag}_colorized.pdf')

    fig = plt.figure(dpi=300)
    amplitude = np.clip(np.abs(obj), *amplitude_limits)
    p.plot_amplitude(amplitude, basis=basis, fig=fig)
    # Fix the color scale so every amplitude figure is directly comparable
    plt.clim(*amplitude_limits)
    if save_output:
        save_latest_im(f'figs/lcp_RPI_{tag}_amplitude.png')
        plt.savefig(f'figs/lcp_RPI_{tag}_amplitude.pdf')

    fig = plt.figure(dpi=300)
    phase = np.exp(1j * np.clip(np.angle(obj), -phase_limit, phase_limit))
    p.plot_phase(phase, basis=basis, fig=fig)
    if save_output:
        save_latest_im(f'figs/lcp_RPI_{tag}_phase.png')
        plt.savefig(f'figs/lcp_RPI_{tag}_phase.pdf')


# Example single-shot result.
# NOTE(review): this shows exposure 0, whereas the analyzed series above
# starts at `start`; confirm which exposure is intended.
raw_obj = rpi_results['objs'][0][crop:-crop, crop:-crop]
_plot_rpi_views(_prepare_for_display(raw_obj), 'raw')

#
# And now we plot colorized, amplitude and phase images of the summed images
#

for n_summed in plot_sums:
    summed_obj = summed_comparisons[n_summed - 1]['rpi_obj']
    _plot_rpi_views(_prepare_for_display(summed_obj), f'sum_{n_summed}')


#
# Now we plot the FRCs and SSNRs
#

# Stacked FRC / SSNR curves of every running mean (not used further in this
# script, but kept for interactive inspection).
frcs = np.array([comp['frc'] for comp in summed_comparisons])
ssnrs = np.array([comp['ssnr'] for comp in summed_comparisons])

frc_fig = plt.figure(figsize=(3.75,2.5))
ssnr_fig = plt.figure(figsize=(3.75,2.5))

# First we plot the greyed-out single-shot comparisons
for comp in single_shot_comparisons:
    freqs_um = comp['frc_freqs'] * 1e-6  # to cycles / um, matching the labels
    plt.figure(frc_fig)
    plt.plot(freqs_um, comp['frc'], 'grey')
    plt.figure(ssnr_fig)
    plt.semilogy(freqs_um, comp['ssnr'], 'grey')

# Summary metrics for every running mean; the max SSNR skips the first two
# frequency bins.
max_ssnrs = [np.max(comp['ssnr'][2:]) for comp in summed_comparisons]
epss = [comp['eps'] for comp in summed_comparisons]
# TODO: SSNR at R=1 (disabled; unclear whether 2.5e6 is actually where R=1):
# R_1_ssnrs = [comp['ssnr'][np.nonzero(comp['frc_freqs'] > 2.5e6)[0][0]]
#              for comp in summed_comparisons]

# Highlight the selected running means on top of the single-shot curves
for n_summed in plot_sums:
    comp = summed_comparisons[n_summed - 1]
    label = f'{n_summed} shot{"" if n_summed == 1 else "s"}'

    freqs_um = comp['frc_freqs'] * 1e-6
    plt.figure(frc_fig)
    plt.plot(freqs_um, comp['frc'], label=label)
    plt.figure(ssnr_fig)
    plt.semilogy(freqs_um, comp['ssnr'], label=label)


plt.figure(frc_fig)
plt.grid(True)
plt.xlabel('Full Pitch Spatial Frequency (cycles / um)')
plt.ylabel('FRC (ptycho vs RPI)')
plt.legend()
plt.tight_layout()

# Frequency axis for the reference line and x-limits, taken explicitly
# rather than from a variable leaked out of a loop.
ref_freqs = summed_comparisons[-1]['frc_freqs']

plt.figure(ssnr_fig)
plt.semilogy(ptycho_results['frc_freqs'][0] * 1e-6, ptycho_results['ssnr'][0], 'k', label='ptycho')
plt.grid(True, which='both')
plt.xlabel('Spatial Frequency (cycles / um)')
plt.ylabel('SSNR')
plt.legend()
# Dashed horizontal reference line at SSNR = 0.4142 (~ sqrt(2) - 1)
plt.plot(ref_freqs * 1e-6, np.full(len(ref_freqs), 0.4142), 'k--')
plt.xlim([0, np.max(ref_freqs) * 1e-6])
plt.tight_layout()

shot_counts = np.arange(1, len(max_ssnrs) + 1)

plt.figure(figsize=(3.75,2.5))
plt.plot(shot_counts, max_ssnrs)
plt.grid(True)
plt.xlabel('Number of shots')
plt.ylabel('Maximum nonzero SSNR value')
plt.tight_layout()

plt.figure(figsize=(3.75,2.5))
plt.semilogy(shot_counts, epss)
plt.grid(True)
plt.xlabel('Number of shots')
plt.ylabel('Normalized MSE')
plt.tight_layout()

#plt.figure(figsize=(3.75,2.5))
#plt.plot([idx + 1 for idx in range(len(R_1_ssnrs))], R_1_ssnrs)
#plt.grid(True)
#plt.xlabel('Number of shots')
#plt.ylabel('SSNR at R=1')
#plt.tight_layout()

plt.show()
